from mimetypes import init
from torch import double
from src.discretizer import Discretizer
from src.abstraction import InterfaceAbstraction, CentroidAbstraction
from src.data_structures import Option, SingleValueSampler, State
from src.search import Search
from src.ll_control import PolicyTrainer as Train
from src.simulator import Simulator
from src.mp import HARP
import os, sys, shutil
import pickle
import numpy as np
from scipy.spatial.kdtree import KDTree

class HARLProblem(object): 

    def __init__(self,robot,mp_robot,problem_config,runid,seed,abstraction_type):
        """Configure a problem instance and build (or reload) its abstraction.

        Args:
            robot: Robot model handed to the main ``Discretizer``.
            mp_robot: Robot model handed to the motion-planning ``Discretizer``.
            problem_config: Mapping read for data paths, bin counts, the goal
                tolerance, the per-option timeout and the criticality
                threshold. NOTE: its ``'policy_folder'`` entry is overwritten
                in place with the run-specific folder.
            runid: Run identifier; used as a sub-folder of the policy folder.
            seed: Seed stored on the instance for later use.
            abstraction_type: ``'centroid'`` or ``'interface'``.

        Raises:
            ValueError: If no cached abstraction exists and
                ``abstraction_type`` is not one of the supported choices.
        """
        self.i=0
        self.simulator = None
        self.mp_simulator = None
        self.robot = robot
        self.seed = seed
        self.runid = runid
        self.problem_number=-1
        self.mp_robot = mp_robot
        self.problem_config = problem_config

        env_name = self.problem_config['env_name']
        self.preds_path = os.path.join(self.problem_config["preds_path"],env_name+"_"+self.problem_config["robot"]["name"]+".npy")
        self.goal_tolerance = float(self.problem_config["goal_tolerance"])
        self.env_path = os.path.join(self.problem_config["env_mask_path"],env_name+".npy")
        self.policy_folder = os.path.join(self.problem_config['policy_folder'], runid)

        # Temp fix to save pytorch models in transit
        self.problem_config['policy_folder'] = self.policy_folder

        self.n_xy_bins = int(self.problem_config["n_xy_bins"])
        self.n_dof_bins = int(self.problem_config["n_dof_bins"])
        self.timeout_per_option = int(self.problem_config["timeout_per_option"])
        self.discretizer = Discretizer(self.robot,self.n_xy_bins,self.n_dof_bins)
        self.mp_discretizer = Discretizer(self.mp_robot,self.n_xy_bins,self.n_dof_bins)
        self.abstraction_type = abstraction_type

        # The abstraction is cached per (environment, abstraction type) inside
        # the run folder so later runs with the same run id can skip rebuilding.
        abs_path = os.path.join(self.policy_folder,"{}_{}_abs.p".format(env_name,abstraction_type))
        if os.path.exists(abs_path):
            # NOTE: the cache is a pickle; only load files written by this code.
            with open(abs_path,"rb") as f:
                self.abstraction = pickle.load(f)
        else:
            # nn_preds / env_mask are only populated when the abstraction is
            # built here; they are not set when it is reloaded from the cache.
            self.nn_preds = self.load_pd()
            self.env_mask = self.load_env_mask()
            if abstraction_type == 'centroid':
                self.abstraction = CentroidAbstraction(self.nn_preds,self.discretizer,self.env_mask,self.problem_config["criticality_threshold"])
            elif abstraction_type == 'interface':
                self.abstraction = InterfaceAbstraction(self.nn_preds,self.discretizer,self.env_mask,self.problem_config["criticality_threshold"])
            else:
                raise ValueError("Undefined abstraction type {}. Choices: {}".format(abstraction_type, ", ".join(["centroid", "interface"])))
            # A fresh run id has no folder yet; create it so the write below
            # does not fail with FileNotFoundError.
            os.makedirs(self.policy_folder, exist_ok=True)
            with open(abs_path,"wb") as f:
                pickle.dump(self.abstraction,f)
        self.abstraction.plot()
        self.options = self.__initialize_options()
        self.trainer = None

    def load_pd(self):
        """Load the array stored at ``self.preds_path`` with singleton axes removed."""
        return np.squeeze(np.load(self.preds_path))
    
    def load_env_mask(self):
        """Return index 0 along the third axis of the squeezed array at ``self.env_path``."""
        raw_mask = np.load(self.env_path)
        squeezed = np.squeeze(raw_mask)
        return squeezed[:, :, 0]

    def store_option(self,option):
        """Persist an option's policy file, cost and guide into the policy folder.

        The policy file at ``option.policy`` is copied, the cost is written as
        text, and the guide together with the switch point is pickled. File
        names are built from the robot name, the environment name and the
        ids of the option's source and destination.

        Returns:
            Tuple ``(option_filename, cost_filename, guide_filename)``.
        """
        stem = "{}_{}".format(self.problem_config['robot']['name'], self.problem_config['env_name'])
        edge = "{}_{}".format(option.src.id, option.dest.id)

        def target(tag, extension):
            # e.g. <robot>_<env>_o<src>_<dest>.pt
            return os.path.join(self.policy_folder, "{}_{}{}.{}".format(stem, tag, edge, extension))

        option_filename = target("o", "pt")
        cost_filename = target("c", "txt")
        guide_filename = target("g", "pickle")

        shutil.copyfile(option.policy, option_filename)
        with open(cost_filename, "w") as cost_file:
            cost_file.write(str(option.cost))
        payload = {"guide": option.guide, "switch_pt": option.switch_point}
        with open(guide_filename, "wb") as guide_file:
            pickle.dump(payload, guide_file)
        return option_filename, cost_filename, guide_filename
    

    def __initialize_options(self):
        """Build the option table for every ordered pair of abstract states.

        Each entry starts as a fresh ``Option`` carrying the abstraction's
        heuristic estimate. If a policy file for the pair already exists in
        the policy folder, the option is additionally populated with the
        stored policy path, cost, guide and switch point.

        File names are keyed on the state *ids* so that they match the names
        written by ``store_option``; formatting the state objects themselves
        only matches when their string form happens to equal the id.

        Returns:
            Nested dict ``options[src.id][dest.id] -> Option``.
        """
        name_prefix = "{}_{}".format(self.problem_config['robot']['name'],self.problem_config['env_name'])
        options = {}
        for src in self.abstraction.get_states():
            options[src.id] = {}
            for dest in self.abstraction.get_states():
                option = Option(src,dest,self.abstraction.estimate_heuristic(src,dest))
                options[src.id][dest.id] = option

                edge = "{}_{}".format(src.id,dest.id)
                policy_path = os.path.join(self.policy_folder,"{}_o{}.pt".format(name_prefix,edge))
                if not os.path.exists(policy_path):
                    continue

                # The cost and guide files are expected to sit next to the
                # policy; a missing one still raises, as before.
                option.policy = policy_path
                with open(os.path.join(self.policy_folder,"{}_c{}.txt".format(name_prefix,edge))) as f:
                    option.cost = float(f.read())
                with open(os.path.join(self.policy_folder,"{}_g{}.pickle".format(name_prefix,edge)),"rb") as f:
                    g = pickle.load(f)
                option.guide = g['guide']
                option.switch_point = g['switch_pt']
        return options
    
    def learn_ll_policy(self, init_config, goal_config, mode=1):
        '''
        Learn a low-level policy for the first or last leg of a plan.

        Mode 1: goal is a point. ``init_config`` is a region object (its
                ``id`` and ``get_centroid()`` are used) and ``goal_config`` is
                a low-level configuration.
        Mode 2: source is a point. ``init_config`` is a low-level
                configuration and ``goal_config`` is either a list (a
                low-level configuration that is mapped to its abstract state)
                or a region object.

        Returns:
            On success, a list holding the values returned by the trainer
            followed by the ``Option`` built here. If no guide could be found
            after 10 attempts, the tuple
            ``(None, inf, "timeout", inf, inf, None)``.

        Raises:
            ValueError: If ``mode`` is neither 1 nor 2.
        '''
        # Reject unsupported modes up front instead of failing later on an
        # unbound local.
        if mode not in (1, 2):
            raise ValueError("Undefined mode {}. Choices: 1, 2".format(mode))

        if mode == 1:
            goal_config = list(goal_config)
            src = src_hl = init_config
            dest_hl = self.abstraction.get_abstract_state(goal_config)
            dest = goal_config
            init_sampler = self.abstraction.get_sampler(src,mode=1)
            term_sampler = SingleValueSampler(goal_config)

            start = src.get_centroid()
            goal = [self.discretizer.get_bin_from_ll(value,dof) for dof,value in enumerate(goal_config)]
            # Region ids are either a single value or a tuple; normalize to a
            # tuple so both cases can be merged into one set.
            src_id = src_hl.id if isinstance(src_hl.id, tuple) else (src_hl.id, )
            dest_id = dest_hl.id if isinstance(dest_hl.id, tuple) else (dest_hl.id, )
            hl_regions = set(src_id).union(dest_id)
        else:
            init_config = list(init_config)
            src_hl = self.abstraction.get_abstract_state(init_config)
            src = init_config
            if isinstance(goal_config, list):
                dest = dest_hl = self.abstraction.get_abstract_state(goal_config)
                hl_regions = None
            else:
                dest = dest_hl = goal_config # here goal_config = first interface region.
                if isinstance(goal_config.id, int): # Ids are unsubscriptable integers for centroid regions
                    hl_regions = [goal_config.id]
                else:
                    hl_regions = set(goal_config.id)

            term_sampler = self.abstraction.get_sampler(dest_hl,mode=1)
            init_sampler = SingleValueSampler(init_config)

            start = [self.discretizer.get_bin_from_ll(value,dof) for dof,value in enumerate(init_config)]
            goal  = dest.get_centroid()

        option = Option(src_hl,dest_hl,self.abstraction.estimate_heuristic(src_hl,dest_hl))

        # Retry the planner up to 10 times; an empty second element of its
        # result is treated as "no guide found".
        for _ in range(10):
            option_guide = HARP(start,
                                goal,
                                self.abstraction.get_abstract_state_sampler(option,discretizer=self.discretizer),
                                self.abstraction.get_uniform_sampler(discretizer=self.discretizer),
                                self.mp_simulator.get_collision_fn(),
                                self.discretizer,
                                self.abstraction,
                                simulator=self.mp_simulator,
                                hl_regions=hl_regions).get_mp()
            if option_guide[1] != []:
                break
        if option_guide[1] == []:
            return None,float("inf"),"timeout",float('inf'),float('inf'),None

        option.guide = KDTree(option_guide[1])
        option.switch_point = option_guide[2]
        self.abstraction.set_evaluation_function(src,dest)
        # The abstraction object itself is passed as the evaluation function.
        eval_func = self.abstraction

        trainer = Train(self.problem_config, init_sampler, eval_func, term_sampler,policy_type='point_policy')
        trained_info = trainer.train(init_sampler,
                                  eval_func,
                                  term_sampler,
                                  option.guide,
                                  None,
                                  tblog_prefix=self.tblog_prefix,
                                  seed=self.seed)
        trained_info = list(trained_info)
        trained_info.append(option)
        return trained_info

    def learn_policy_for_option(self,option,mode = 1):
        """Train a region-to-region policy for ``option`` and return the trainer's result.

        Start states are sampled from ``option.src`` and termination states
        from ``option.dest``; the abstraction doubles as the evaluation
        function after being pointed at the option's endpoints. ``mode`` is
        accepted but not used by this method.
        """
        abstraction = self.abstraction
        start_sampler = abstraction.get_sampler(option.src, mode=1)
        end_sampler = abstraction.get_sampler(option.dest)
        abstraction.set_evaluation_function(option.src, option.dest)

        region_trainer = Train(
            self.problem_config,
            start_sampler,
            abstraction,
            end_sampler,
            policy_type='region_policy',
        )
        outcome = region_trainer.train(
            start_sampler,
            abstraction,
            end_sampler,
            option.guide,
            option.switch_point,
            tblog_prefix=self.tblog_prefix,
            seed=self.seed,
        )
        return outcome

    def set_samplers(self,option):
        """Point the abstraction's evaluation function at the option's source and destination."""
        endpoints = (option.src, option.dest)
        self.abstraction.set_evaluation_function(*endpoints)

    def solve(self,init_config,goal_config):
        '''
        Solve one problem: search for a high-level plan, then learn (or reuse)
        a policy for every leg.

        Legs, in order:
          1. point-to-region: from ``init_config`` to the first region,
          2. region-to-region: one per option of the high-level plan(s),
          3. region-to-point: from the last region to ``goal_config``.

        When the last leg succeeds, the collected policies are pickled to
        ``<policy_log_dir>/p<problem_number>.pickle``.

        Failure categories from the original notes (this method only prints
        failures; it does not record these labels):
        fail_0: Passed the last leg policy training but did not reach the goal in sim
        fail_1: Failed the last leg policy eval
        fail_2: Failed to find a region level policy

        Returns:
            ``(train_timesteps, eval_timesteps)``: two lists with one entry
            per leg for which training was attempted.
        '''
        eval_timesteps = []
        train_timesteps = []
        init_hl_state = self.abstraction.get_abstract_state(init_config)
        goal_hl_state = self.abstraction.get_abstract_state(goal_config)
        init_closest = self.abstraction.get_closest_region(init_config)
        goal_closest = self.abstraction.get_closest_region(goal_config)

        policies_to_exec = []
        hl_solutions, path_has_options = Search.astar_search(init_closest,
                                          goal_closest,
                                          init_hl_state,
                                          goal_hl_state,
                                          self.abstraction.estimate_heuristic,
                                          self.options,
                                          self.abstraction,
                                          self.abstraction.connectivity)

        if path_has_options:
            for solution in hl_solutions:
                print("\n")
                for p in solution:
                    print("{}->{} ".format(p.src.id,p.dest.id))
        #####################################################################
        print("Learning point to region policy")
        # NOTE(review): `failed` is never set to True below, so the last leg
        # is always attempted once the first leg has succeeded.
        failed = False
        if path_has_options:
            first_region = hl_solutions[0][0].src
        else:
            # Without options the search result's first entry is used as a key
            # into the abstraction's region table. Use the table matching the
            # abstraction type for BOTH the first and the last leg; the last
            # leg previously always read `interfaces`.
            if self.abstraction_type == "interface":
                first_region = self.abstraction.interfaces[hl_solutions[0]]
            else:
                first_region = self.abstraction.abstract_states[hl_solutions[0]]
            p2r_dest = first_region
            hl_solutions = [[]]
        p1,cost,status,eval_ep_len,training_steps,option = self.learn_ll_policy(init_config, first_region, mode=2)
        train_timesteps.append(training_steps)
        eval_timesteps.append(eval_ep_len)

        if status == "timeout":
            print("FAILED: Point to region policy")
            return train_timesteps, eval_timesteps

        policies_to_exec.append((p1,option.guide,option.switch_point, self.abstraction.eval))

        for plan in hl_solutions:
            for option in plan:
                if option.has_policy():
                    print("Reusing option {}->{}".format(option.src.id, option.dest.id))
                    self.set_samplers(option)
                    policies_to_exec.append((option.policy,option.guide,option.switch_point,self.abstraction.eval))
                    continue

                print("Learning region to region policy: option from {} to {}".format(option.src.id,option.dest.id))
                # Region ids are either a single value or a tuple; normalize
                # to tuples so they can be merged into one set.
                src_id = option.src.id if isinstance(option.src.id, tuple) else (option.src.id, )
                dest_id = option.dest.id if isinstance(option.dest.id, tuple) else (option.dest.id, )
                hl_regions = set(src_id).union(dest_id)

                # Retry the planner up to 10 times until it returns a
                # non-empty second element.
                for _ in range(10):
                    option_guide = HARP(option.src.get_centroid(),
                                        option.dest.get_centroid(),
                                        self.abstraction.get_critical_region_sampler(option,discretizer=self.discretizer),
                                        self.abstraction.get_abstract_state_sampler(option,discretizer=self.discretizer),
                                        self.mp_simulator.get_collision_fn(),
                                        self.discretizer,
                                        self.abstraction,
                                        simulator=self.mp_simulator,
                                        max_time=self.problem_config["HARP"]["timeout"],
                                        hl_regions=hl_regions).get_mp()
                    if option_guide[1] != []:
                        break

                # Abandon the rest of this plan when the planner reports no
                # result in its first element.
                if not option_guide[0]:
                    break
                option.guide = KDTree(option_guide[1])
                option.switch_point = option_guide[2]
                print("Learning option {}->{}".format(option.src.id, option.dest.id))
                policy, cost, status,eval_ep_len,training_steps = self.learn_policy_for_option(option)
                train_timesteps.append(training_steps)
                eval_timesteps.append(eval_ep_len)
                if status == "okay":
                    option.policy = policy
                    option.cost = cost
                    # Persist the option and point it at the stored copy.
                    option.policy, _, _ = self.store_option(option)
                    policies_to_exec.append((option.policy,option.guide,option.switch_point,self.abstraction.eval))
                if status == "timeout":
                    print("FAILED: Region to region policy from {}->{}".format(option.src.id, option.dest.id))
                    break

        if not failed:
            dest = hl_solutions[0][-1].dest if path_has_options else p2r_dest
            p2, cost, status,eval_ep_len,training_steps,option = self.learn_ll_policy(dest,goal_config,mode=1)
            train_timesteps.append(training_steps)
            eval_timesteps.append(eval_ep_len)
            if status == "okay":
                policies_to_exec.append((p2,option.guide,option.switch_point, self.abstraction.eval))
                # The log folder may not exist yet for the first problem of a run.
                os.makedirs(self.policy_log_dir, exist_ok=True)
                with open(os.path.join(self.policy_log_dir, "p{}.pickle".format(self.problem_number)),'wb') as f:
                    pickle.dump(policies_to_exec, f)
            else:
                print("FAILED: Region to point policy")
        return train_timesteps, eval_timesteps

    def generate_and_test(self):
        """Run every configured test problem and save per-problem step counts.

        Loads the pickled init/goal locations for this robot and environment,
        then for each of the first ``num_test_problems`` entries builds a
        motion-planning simulator, calls ``solve`` and writes the returned
        training/evaluation step lists under
        ``<policy_log_dir>/<problem number>/``. Aggregated lists across all
        problems are written to ``policy_log_dir`` at the end.

        Side effects: sets ``action_log_dir``, ``policy_log_dir``,
        ``tblog_prefix``, ``mp_simulator`` and ``problem_number`` on the
        instance.

        Returns:
            None.
        """
        robot_name = self.problem_config['robot']['name']
        env_name = self.problem_config['env_name']

        def as_array(step_lists):
            # Problems can end after a different number of legs, giving ragged
            # lists. NumPy >= 1.24 raises ValueError on those unless an object
            # array is requested explicitly; regular input keeps its dtype.
            try:
                return np.array(step_lists)
            except ValueError:
                return np.array(step_lists, dtype=object)

        experiment_fname = '{}_{}.pickle'.format(robot_name, env_name)
        experiment = os.path.join(self.problem_config['experiments_path'],experiment_fname)
        with open(experiment, 'rb') as f:
            load_locations = pickle.load(f)
        self.action_log_dir = os.path.join(self.problem_config['action_logs'],
                                           robot_name,
                                           env_name,
                                           self.runid)
        self.policy_log_dir = os.path.join(self.problem_config['policy_logs'],
                                           robot_name,
                                           env_name,
                                           self.runid)
        # Needed by the final saves even when no problem is run.
        os.makedirs(self.policy_log_dir, exist_ok=True)

        cross_task_trainsteps = []
        cross_task_evalsteps = []
        for i in range(self.problem_config["num_test_problems"]):
            problem_number = i + 1
            # TODO: remove the [:3] slice (the original note is truncated).
            init = load_locations['init'][i][:3]
            goal = load_locations['goal'][i][:3]

            init_sampler = SingleValueSampler(init)
            term_sampler = SingleValueSampler(goal)
            # The abstraction object itself is passed as the evaluation function.
            eval_func = self.abstraction
            self.tblog_prefix = "_".join([self.robot.name,env_name])+"/{}/{}".format(self.runid,problem_number)
            print("Running problem #{0}".format(problem_number))
            self.mp_simulator = Simulator(init_sampler,
                                            eval_func,
                                            self.problem_config,
                                            term_sampler,
                                            mp=True,
                                            seed=self.seed)

            self.problem_number = problem_number
            train_steps, eval_steps = self.solve(init,goal)

            eval_file_dir=os.path.join(self.policy_log_dir,str(problem_number))
            cross_task_trainsteps.append(train_steps)
            cross_task_evalsteps.append(eval_steps)
            os.makedirs(eval_file_dir,exist_ok=True)
            np.save(os.path.join(eval_file_dir,'train_steps.npy'),np.array([train_steps]))
            np.savetxt(os.path.join(eval_file_dir,'train_steps.txt'),np.array([train_steps]))
            np.save(os.path.join(eval_file_dir,'eval_steps.npy'),np.array(eval_steps))
            np.savetxt(os.path.join(eval_file_dir,'eval_steps.txt'),np.array(eval_steps))

        np.save(os.path.join(self.policy_log_dir,'crosstask_train_steps.npy'),as_array(cross_task_trainsteps))
        np.save(os.path.join(self.policy_log_dir,'crosstask_eval_steps.npy'),as_array(cross_task_evalsteps))
        return None
